! (C) Copyright 2000- ECMWF.
!
! This software is licensed under the terms of the Apache Licence Version 2.0
! which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
! In applying this licence, ECMWF does not waive the privileges and immunities
! granted to it by virtue of its status as an intergovernmental organisation
! nor does it submit to any jurisdiction.
!

#if defined CUDAGPU
#define ACC_GET_HIP_STREAM ACC_GET_CUDA_STREAM
#define OPENACC_LIB OPENACC
#endif

MODULE HICBLAS_MOD

USE EC_PARKIND, ONLY: JPIM, JPRM, JPRD
USE GROWING_ALLOCATOR_MOD, ONLY: GROWING_ALLOCATION_TYPE
USE OPENACC_LIB, ONLY: ACC_GET_HIP_STREAM

IMPLICIT NONE

INTERFACE
  ! Explicit interfaces to the batched GEMM wrappers implemented in C.
  ! Scalars carrying the VALUE attribute are passed by value; the matrices
  ! A, B, C and the STREAM handle are passed by reference. ALLOC is an
  ! opaque C pointer passed by value.
  ! NOTE(review): TDA, TDB and TDC receive STRIDEA, STRIDEB and STRIDEC from
  ! the *_BATCHED_OVERLOAD wrappers of this module - presumably the distance
  ! between consecutive matrices of the batch; confirm against the C source.

  ! Double-precision variant, bound to the C symbol 'hipblas_dgemm_wrapper'.
  SUBROUTINE HIP_DGEMM_BATCHED(                &
    & CTA, CTB, M, N, K,                       &
    & ALPHA, A, LDA, TDA, B, LDB, TDB,         &
    & BETA, C, LDC, TDC,                       &
    & BATCHCOUNT, STREAM, ALLOC)               &
    & BIND(C, NAME='hipblas_dgemm_wrapper')
    USE ISO_C_BINDING, ONLY: C_CHAR, C_INT, C_DOUBLE, C_SIZE_T, C_PTR
    ! Transposition selectors for A and B (single characters)
    CHARACTER(KIND=C_CHAR, LEN=1), VALUE :: CTA, CTB
    ! Problem dimensions
    INTEGER(KIND=C_INT), VALUE :: M, N, K
    ! Leading dimension and per-matrix stride-like argument of A, B and C
    INTEGER(KIND=C_INT), VALUE :: LDA, TDA
    INTEGER(KIND=C_INT), VALUE :: LDB, TDB
    INTEGER(KIND=C_INT), VALUE :: LDC, TDC
    INTEGER(KIND=C_INT), VALUE :: BATCHCOUNT
    ! Scaling factors
    REAL(KIND=C_DOUBLE), VALUE :: ALPHA, BETA
    ! Matrix storage; the first extent is the leading dimension
    REAL(KIND=C_DOUBLE) :: A(LDA,*)
    REAL(KIND=C_DOUBLE) :: B(LDB,*)
    REAL(KIND=C_DOUBLE) :: C(LDC,*)
    ! Stream handle, passed by reference (no VALUE attribute)
    INTEGER(KIND=C_SIZE_T) :: STREAM
    ! Opaque pointer to the allocator object
    TYPE(C_PTR), INTENT(IN), VALUE :: ALLOC
  END SUBROUTINE HIP_DGEMM_BATCHED

  ! Single-precision variant, bound to the C symbol 'hipblas_sgemm_wrapper'.
  ! The argument list mirrors HIP_DGEMM_BATCHED with C_FLOAT reals.
  SUBROUTINE HIP_SGEMM_BATCHED(                &
    & CTA, CTB, M, N, K,                       &
    & ALPHA, A, LDA, TDA, B, LDB, TDB,         &
    & BETA, C, LDC, TDC,                       &
    & BATCHCOUNT, STREAM, ALLOC)               &
    & BIND(C, NAME='hipblas_sgemm_wrapper')
    USE ISO_C_BINDING, ONLY: C_CHAR, C_INT, C_FLOAT, C_SIZE_T, C_PTR
    CHARACTER(KIND=C_CHAR, LEN=1), VALUE :: CTA, CTB
    INTEGER(KIND=C_INT), VALUE :: M, N, K
    INTEGER(KIND=C_INT), VALUE :: LDA, TDA
    INTEGER(KIND=C_INT), VALUE :: LDB, TDB
    INTEGER(KIND=C_INT), VALUE :: LDC, TDC
    INTEGER(KIND=C_INT), VALUE :: BATCHCOUNT
    REAL(KIND=C_FLOAT), VALUE :: ALPHA, BETA
    REAL(KIND=C_FLOAT) :: A(LDA,*)
    REAL(KIND=C_FLOAT) :: B(LDB,*)
    REAL(KIND=C_FLOAT) :: C(LDC,*)
    ! Stream handle, passed by reference (no VALUE attribute)
    INTEGER(KIND=C_SIZE_T) :: STREAM
    TYPE(C_PTR), INTENT(IN), VALUE :: ALLOC
  END SUBROUTINE HIP_SGEMM_BATCHED
END INTERFACE

INTERFACE
  ! Explicit interfaces to the grouped GEMM wrappers implemented in C.
  ! Unlike the batched wrappers, N, K and the three OFFSET arguments are
  ! integer arrays (assumed size), while M and the leading dimensions stay
  ! scalar and are passed by value. A, B and C are flat assumed-size arrays.
  ! NOTE(review): presumably each of N, K, OFFSETA, OFFSETB, OFFSETC holds one
  ! entry per matrix product (BATCHCOUNT entries) - confirm against the C
  ! source.

  ! Double-precision variant, bound to 'hipblas_dgemm_wrapper_grouped'.
  SUBROUTINE HIP_DGEMM_GROUPED(                &
    & BLAS_ID, CTA, CTB, M, N, K,              &
    & ALPHA, A, LDA, OFFSETA, B, LDB, OFFSETB, &
    & BETA, C, LDC, OFFSETC,                   &
    & BATCHCOUNT, STREAM, ALLOC)               &
    & BIND(C, NAME='hipblas_dgemm_wrapper_grouped')
    USE ISO_C_BINDING, ONLY: C_CHAR, C_INT, C_DOUBLE, C_SIZE_T, C_PTR
    ! Transposition selectors for A and B (single characters)
    CHARACTER(KIND=C_CHAR, LEN=1), VALUE :: CTA, CTB
    ! Integer identifier forwarded unchanged to the C wrapper
    INTEGER(KIND=C_INT), VALUE :: BLAS_ID
    ! Scalar dimension and leading dimensions
    INTEGER(KIND=C_INT), VALUE :: M, LDA, LDB, LDC
    INTEGER(KIND=C_INT), VALUE :: BATCHCOUNT
    ! Array-valued dimensions and offsets, passed by reference
    INTEGER(KIND=C_INT) :: N(*), K(*)
    INTEGER(KIND=C_INT) :: OFFSETA(*), OFFSETB(*), OFFSETC(*)
    ! Scaling factors
    REAL(KIND=C_DOUBLE), VALUE :: ALPHA, BETA
    ! Matrix storage
    REAL(KIND=C_DOUBLE) :: A(*), B(*), C(*)
    ! Stream handle, passed by reference (no VALUE attribute)
    INTEGER(KIND=C_SIZE_T) :: STREAM
    ! Opaque pointer to the allocator object
    TYPE(C_PTR), INTENT(IN), VALUE :: ALLOC
  END SUBROUTINE HIP_DGEMM_GROUPED

  ! Single-precision variant, bound to 'hipblas_sgemm_wrapper_grouped'.
  ! The argument list mirrors HIP_DGEMM_GROUPED with C_FLOAT reals.
  SUBROUTINE HIP_SGEMM_GROUPED(                &
    & BLAS_ID, CTA, CTB, M, N, K,              &
    & ALPHA, A, LDA, OFFSETA, B, LDB, OFFSETB, &
    & BETA, C, LDC, OFFSETC,                   &
    & BATCHCOUNT, STREAM, ALLOC)               &
    & BIND(C, NAME='hipblas_sgemm_wrapper_grouped')
    USE ISO_C_BINDING, ONLY: C_CHAR, C_INT, C_FLOAT, C_SIZE_T, C_PTR
    CHARACTER(KIND=C_CHAR, LEN=1), VALUE :: CTA, CTB
    INTEGER(KIND=C_INT), VALUE :: BLAS_ID
    INTEGER(KIND=C_INT), VALUE :: M, LDA, LDB, LDC
    INTEGER(KIND=C_INT), VALUE :: BATCHCOUNT
    INTEGER(KIND=C_INT) :: N(*), K(*)
    INTEGER(KIND=C_INT) :: OFFSETA(*), OFFSETB(*), OFFSETC(*)
    REAL(KIND=C_FLOAT), VALUE :: ALPHA, BETA
    REAL(KIND=C_FLOAT) :: A(*), B(*), C(*)
    ! Stream handle, passed by reference (no VALUE attribute)
    INTEGER(KIND=C_SIZE_T) :: STREAM
    TYPE(C_PTR), INTENT(IN), VALUE :: ALLOC
  END SUBROUTINE HIP_SGEMM_GROUPED
END INTERFACE

CONTAINS

SUBROUTINE HIP_DGEMM_BATCHED_OVERLOAD( &
  & TRANSA, TRANSB, &
  & M, N, K, &
  & ALPHA, &
  & AARRAY, LDA, STRIDEA, &
  & BARRAY, LDB, STRIDEB, &
  & BETA, &
  & CARRAY, LDC, STRIDEC, &
  & BATCHCOUNT, STREAM, ALLOC)
  ! Double-precision (JPRD) Fortran front end to the C routine declared as
  ! HIP_DGEMM_BATCHED in this module.
  !
  ! The routine
  !   1. converts the integer STREAM into a native stream handle through
  !      ACC_GET_HIP_STREAM (renamed to the CUDA variant when CUDAGPU is
  !      defined),
  !   2. forwards every argument unchanged to HIP_DGEMM_BATCHED, passing the
  !      C address of ALLOC as the last argument.
  !
  ! Arguments
  !   TRANSA, TRANSB   single-character transposition selectors
  !   M, N, K          problem dimensions
  !   ALPHA, BETA      scaling factors
  !   AARRAY, BARRAY, CARRAY
  !                    matrix storage (rank 1, rank 2 and rank 1 here)
  !   LDA, LDB, LDC    leading dimensions
  !   STRIDEA/B/C      forwarded to the TDA/TDB/TDC arguments of the C wrapper
  !   BATCHCOUNT       number of matrix products
  !   STREAM           presumably an OpenACC async queue id - TODO confirm
  !   ALLOC            allocator object handed to the C wrapper by address
  USE ISO_C_BINDING, ONLY: C_CHAR, C_INT, C_SIZE_T, C_LOC
  CHARACTER(1,C_CHAR), VALUE :: TRANSA, TRANSB
  INTEGER(KIND=JPIM) :: M
  INTEGER(KIND=JPIM) :: N
  INTEGER(KIND=JPIM) :: K
  REAL(KIND=JPRD) :: ALPHA
  REAL(KIND=JPRD), DIMENSION(:) :: AARRAY
  INTEGER(KIND=JPIM) :: LDA
  INTEGER(KIND=JPIM) :: STRIDEA
  REAL(KIND=JPRD), DIMENSION(:,:) :: BARRAY
  INTEGER(KIND=JPIM) :: LDB
  INTEGER(KIND=JPIM) :: STRIDEB
  REAL(KIND=JPRD) :: BETA
  REAL(KIND=JPRD), DIMENSION(:) :: CARRAY
  INTEGER(KIND=JPIM) :: LDC
  INTEGER(KIND=JPIM) :: STRIDEC
  INTEGER(KIND=JPIM) :: BATCHCOUNT
  INTEGER(KIND=C_INT) :: STREAM
  ! TARGET is required by the standard for the argument of C_LOC.
  TYPE(GROWING_ALLOCATION_TYPE), INTENT(IN), TARGET :: ALLOC

  ! Same kind as the STREAM dummy of HIP_DGEMM_BATCHED (C_SIZE_T). C_LONG only
  ! happens to have that kind on LP64 platforms.
  INTEGER(KIND=C_SIZE_T) :: HIP_STREAM

  HIP_STREAM = INT(ACC_GET_HIP_STREAM(STREAM), C_SIZE_T)

  ! With the Cray compiler the call is enclosed in a HOST_DATA region so that
  ! the device addresses of the three arrays are handed to the C wrapper.
#if defined(_CRAYFTN)
  !$ACC HOST_DATA USE_DEVICE(AARRAY,BARRAY,CARRAY)
#endif
  CALL HIP_DGEMM_BATCHED( &
    & TRANSA, TRANSB, &
    & M, N, K, &
    & ALPHA, &
    & AARRAY, LDA, STRIDEA, &
    & BARRAY, LDB, STRIDEB, &
    & BETA, &
    & CARRAY, LDC, STRIDEC, &
    & BATCHCOUNT, HIP_STREAM, C_LOC(ALLOC))
#if defined(_CRAYFTN)
  !$ACC END HOST_DATA
#endif
END SUBROUTINE HIP_DGEMM_BATCHED_OVERLOAD

SUBROUTINE HIP_SGEMM_BATCHED_OVERLOAD( &
  & TRANSA, TRANSB, &
  & M, N, K, &
  & ALPHA, &
  & AARRAY, LDA, STRIDEA, &
  & BARRAY, LDB, STRIDEB, &
  & BETA, &
  & CARRAY, LDC, STRIDEC, &
  & BATCHCOUNT, STREAM, ALLOC)
  ! Single-precision (JPRM) Fortran front end to the C routine declared as
  ! HIP_SGEMM_BATCHED in this module.
  !
  ! The routine converts the integer STREAM into a native stream handle
  ! through ACC_GET_HIP_STREAM (renamed to the CUDA variant when CUDAGPU is
  ! defined) and forwards every argument unchanged to HIP_SGEMM_BATCHED,
  ! passing the C address of ALLOC as the last argument.
  !
  ! Arguments
  !   TRANSA, TRANSB   single-character transposition selectors
  !   M, N, K          problem dimensions
  !   ALPHA, BETA      scaling factors
  !   AARRAY, BARRAY, CARRAY
  !                    matrix storage (BARRAY is assumed size here)
  !   LDA, LDB, LDC    leading dimensions
  !   STRIDEA/B/C      forwarded to the TDA/TDB/TDC arguments of the C wrapper
  !   BATCHCOUNT       number of matrix products
  !   STREAM           presumably an OpenACC async queue id - TODO confirm
  !   ALLOC            allocator object handed to the C wrapper by address
  USE ISO_C_BINDING, ONLY: C_CHAR, C_INT, C_SIZE_T, C_LOC
  CHARACTER(1,C_CHAR), VALUE :: TRANSA, TRANSB
  INTEGER(KIND=JPIM) :: M
  INTEGER(KIND=JPIM) :: N
  INTEGER(KIND=JPIM) :: K
  REAL(KIND=JPRM) :: ALPHA
  REAL(KIND=JPRM), DIMENSION(:) :: AARRAY
  INTEGER(KIND=JPIM) :: LDA
  INTEGER(KIND=JPIM) :: STRIDEA
  REAL(KIND=JPRM), DIMENSION(*) :: BARRAY
  INTEGER(KIND=JPIM) :: LDB
  INTEGER(KIND=JPIM) :: STRIDEB
  REAL(KIND=JPRM) :: BETA
  REAL(KIND=JPRM), DIMENSION(:) :: CARRAY
  INTEGER(KIND=JPIM) :: LDC
  INTEGER(KIND=JPIM) :: STRIDEC
  INTEGER(KIND=JPIM) :: BATCHCOUNT
  INTEGER(KIND=C_INT) :: STREAM
  ! TARGET is required by the standard for the argument of C_LOC.
  TYPE(GROWING_ALLOCATION_TYPE), INTENT(IN), TARGET :: ALLOC

  ! Same kind as the STREAM dummy of HIP_SGEMM_BATCHED (C_SIZE_T). C_LONG only
  ! happens to have that kind on LP64 platforms.
  INTEGER(KIND=C_SIZE_T) :: HIP_STREAM

  HIP_STREAM = INT(ACC_GET_HIP_STREAM(STREAM), C_SIZE_T)

  ! NOTE(review): HIP_DGEMM_BATCHED_OVERLOAD encloses its call in a HOST_DATA
  ! USE_DEVICE region when _CRAYFTN is defined; this routine does not. It is
  ! not visible from this module whether callers already pass device
  ! addresses here, so the behaviour is left unchanged - confirm with the
  ! call sites whether the asymmetry is intended.
  CALL HIP_SGEMM_BATCHED( &
    & TRANSA, TRANSB, &
    & M, N, K, &
    & ALPHA, &
    & AARRAY, LDA, STRIDEA, &
    & BARRAY, LDB, STRIDEB, &
    & BETA, &
    & CARRAY, LDC, STRIDEC, &
    & BATCHCOUNT, HIP_STREAM, C_LOC(ALLOC))
END SUBROUTINE HIP_SGEMM_BATCHED_OVERLOAD

SUBROUTINE HIP_DGEMM_GROUPED_OVERLOAD( &
  & BLAS_ID, TRANSA, TRANSB, &
  & M, N, K, &
  & ALPHA, &
  & AARRAY, LDA, OFFSETA, &
  & BARRAY, LDB, OFFSETB, &
  & BETA, &
  & CARRAY, LDC, OFFSETC, &
  & BATCHCOUNT, STREAM, ALLOC)
  ! Double-precision (JPRD) Fortran front end to the C routine declared as
  ! HIP_DGEMM_GROUPED in this module.
  !
  ! The routine converts the integer STREAM into a native stream handle
  ! through ACC_GET_HIP_STREAM (renamed to the CUDA variant when CUDAGPU is
  ! defined) and forwards every argument unchanged to HIP_DGEMM_GROUPED,
  ! passing the C address of ALLOC as the last argument.
  !
  ! Arguments
  !   BLAS_ID          integer identifier forwarded to the C wrapper
  !   TRANSA, TRANSB   single-character transposition selectors
  !   M                scalar dimension
  !   N(:), K(:)       array-valued dimensions
  !   ALPHA, BETA      scaling factors
  !   AARRAY, BARRAY, CARRAY
  !                    matrix storage (BARRAY is assumed size here)
  !   LDA, LDB, LDC    leading dimensions
  !   OFFSETA/B/C(:)   integer offset arrays, forwarded as they are
  !   BATCHCOUNT       number of matrix products
  !   STREAM           presumably an OpenACC async queue id - TODO confirm
  !   ALLOC            allocator object handed to the C wrapper by address
  USE ISO_C_BINDING, ONLY: C_INT, C_CHAR, C_SIZE_T, C_LOC
  INTEGER(KIND=C_INT), INTENT(IN) :: BLAS_ID
  CHARACTER(1,C_CHAR), VALUE :: TRANSA, TRANSB
  INTEGER(KIND=JPIM) :: M
  INTEGER(KIND=JPIM) :: N(:)
  INTEGER(KIND=JPIM) :: K(:)
  REAL(KIND=JPRD) :: ALPHA
  REAL(KIND=JPRD), DIMENSION(:) :: AARRAY
  INTEGER(KIND=JPIM) :: LDA
  INTEGER(KIND=JPIM) :: OFFSETA(:)
  REAL(KIND=JPRD), DIMENSION(*) :: BARRAY
  INTEGER(KIND=JPIM) :: LDB
  INTEGER(KIND=JPIM) :: OFFSETB(:)
  REAL(KIND=JPRD) :: BETA
  REAL(KIND=JPRD), DIMENSION(:) :: CARRAY
  INTEGER(KIND=JPIM) :: LDC
  INTEGER(KIND=JPIM) :: OFFSETC(:)
  INTEGER(KIND=JPIM) :: BATCHCOUNT
  INTEGER(KIND=C_INT) :: STREAM
  ! TARGET is required by the standard for the argument of C_LOC.
  TYPE(GROWING_ALLOCATION_TYPE), INTENT(IN), TARGET :: ALLOC

  ! Same kind as the STREAM dummy of HIP_DGEMM_GROUPED (C_SIZE_T). C_LONG only
  ! happens to have that kind on LP64 platforms.
  INTEGER(KIND=C_SIZE_T) :: HIP_STREAM

  HIP_STREAM = INT(ACC_GET_HIP_STREAM(STREAM), C_SIZE_T)

  ! NOTE(review): HIP_SGEMM_GROUPED_OVERLOAD encloses its call in a HOST_DATA
  ! USE_DEVICE region when _CRAYFTN is defined; this routine does not. It is
  ! not visible from this module whether callers already pass device
  ! addresses here, so the behaviour is left unchanged - confirm with the
  ! call sites whether the asymmetry is intended.
  CALL HIP_DGEMM_GROUPED( &
    & BLAS_ID, TRANSA, TRANSB, &
    & M, N, K, &
    & ALPHA, &
    & AARRAY, LDA, OFFSETA, &
    & BARRAY, LDB, OFFSETB, &
    & BETA, &
    & CARRAY, LDC, OFFSETC, &
    & BATCHCOUNT, HIP_STREAM, C_LOC(ALLOC))

END SUBROUTINE HIP_DGEMM_GROUPED_OVERLOAD

SUBROUTINE HIP_SGEMM_GROUPED_OVERLOAD(&
  & BLAS_ID, TRANSA, TRANSB, &
  & M, N, K, &
  & ALPHA, &
  & AARRAY, LDA, OFFSETA, &
  & BARRAY, LDB, OFFSETB, &
  & BETA, &
  & CARRAY, LDC, OFFSETC, &
  & BATCHCOUNT, STREAM, ALLOC)
  ! Single-precision (JPRM) Fortran front end to the C routine declared as
  ! HIP_SGEMM_GROUPED in this module.
  !
  ! The routine converts the integer STREAM into a native stream handle
  ! through ACC_GET_HIP_STREAM (renamed to the CUDA variant when CUDAGPU is
  ! defined) and forwards every argument unchanged to HIP_SGEMM_GROUPED,
  ! passing the C address of ALLOC as the last argument.
  !
  ! Arguments
  !   BLAS_ID          integer identifier forwarded to the C wrapper
  !   TRANSA, TRANSB   single-character transposition selectors
  !   M                scalar dimension
  !   N(:), K(:)       array-valued dimensions
  !   ALPHA, BETA      scaling factors
  !   AARRAY, BARRAY, CARRAY
  !                    matrix storage (BARRAY is rank 3 here)
  !   LDA, LDB, LDC    leading dimensions
  !   OFFSETA/B/C(:)   integer offset arrays, forwarded as they are
  !   BATCHCOUNT       number of matrix products
  !   STREAM           presumably an OpenACC async queue id - TODO confirm
  !   ALLOC            allocator object handed to the C wrapper by address
  USE ISO_C_BINDING, ONLY: C_INT, C_CHAR, C_SIZE_T, C_LOC
  INTEGER(KIND=C_INT), INTENT(IN) :: BLAS_ID
  CHARACTER(1,C_CHAR), VALUE :: TRANSA, TRANSB
  INTEGER(KIND=JPIM) :: M
  INTEGER(KIND=JPIM) :: N(:)
  INTEGER(KIND=JPIM) :: K(:)
  REAL(KIND=JPRM) :: ALPHA
  REAL(KIND=JPRM), DIMENSION(:) :: AARRAY
  INTEGER(KIND=JPIM) :: LDA
  INTEGER(KIND=JPIM) :: OFFSETA(:)
  REAL(KIND=JPRM), DIMENSION(:,:,:) :: BARRAY
  INTEGER(KIND=JPIM) :: LDB
  INTEGER(KIND=JPIM) :: OFFSETB(:)
  REAL(KIND=JPRM) :: BETA
  REAL(KIND=JPRM), DIMENSION(:) :: CARRAY
  INTEGER(KIND=JPIM) :: LDC
  INTEGER(KIND=JPIM) :: OFFSETC(:)
  INTEGER(KIND=JPIM) :: BATCHCOUNT
  INTEGER(KIND=C_INT) :: STREAM
  ! TARGET is required by the standard for the argument of C_LOC.
  TYPE(GROWING_ALLOCATION_TYPE), INTENT(IN), TARGET :: ALLOC

  ! Same kind as the STREAM dummy of HIP_SGEMM_GROUPED (C_SIZE_T). C_LONG only
  ! happens to have that kind on LP64 platforms.
  INTEGER(KIND=C_SIZE_T) :: HIP_STREAM

  HIP_STREAM = INT(ACC_GET_HIP_STREAM(STREAM), C_SIZE_T)

  ! With the Cray compiler the call is enclosed in a HOST_DATA region so that
  ! the device addresses of the three arrays are handed to the C wrapper.
#if defined(_CRAYFTN)
  !$ACC HOST_DATA USE_DEVICE(AARRAY,BARRAY,CARRAY)
#endif
  CALL HIP_SGEMM_GROUPED( &
    & BLAS_ID, TRANSA, TRANSB, &
    & M, N, K, &
    & ALPHA, &
    & AARRAY, LDA, OFFSETA, &
    & BARRAY, LDB, OFFSETB, &
    & BETA, &
    & CARRAY, LDC, OFFSETC, &
    & BATCHCOUNT, HIP_STREAM, C_LOC(ALLOC))
#if defined(_CRAYFTN)
  !$ACC END HOST_DATA
#endif

END SUBROUTINE HIP_SGEMM_GROUPED_OVERLOAD

END MODULE HICBLAS_MOD
